#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@author: Xiaobo Yang
@contact: hal_42@zju.edu.cn
@software: PyCharm
@file: data_manager.py
@time: 2020/3/20 23:22
@desc:
"""
from typing import Union, Optional, Callable, List, Iterable, Tuple

import os
import os.path as osp
import shutil
import pickle
import random

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import DataLoader, Sampler, RandomSampler, SequentialSampler, BatchSampler, DistributedSampler

from alchemy_cat.data import Prefetcher, DataAuger, Dataset, read_rand_seeds
from alchemy_cat.py_tools import indent

kBatchesType = List[List[int]]


class _IdentityMapAuger(DataAuger):
    """Auger whose augmentation graph passes every example through unchanged.

    Used when a DataManager is built from a plain dataset without a DataAuger.
    """

    def build_graph(self):
        # Single pass-through node mapping 'example' -> 'output'.
        def identity_mapping(example):
            return example

        self.graph.register(inputs=['example'], outputs=['output'])(identity_mapping)


class _EpochBatchSampler(Sampler):
    """Batch sampler for DataLoader in DataManager"""

    def __init__(self, batches, epoch_iteration, data_source):
        super(_EpochBatchSampler, self).__init__(data_source)

        self.batches = batches[epoch_iteration:]

    def __iter__(self):
        return iter(self.batches)

    def __len__(self):
        return len(self.batches)


# Copied from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/worker.py; it derives a numpy
# seed state that is decorrelated from the seed used for the `random` module.
def _generate_state(base_seed, worker_id):
    INIT_A = 0x43b0d7e5
    MULT_A = 0x931e8875
    INIT_B = 0x8b51f9dd
    MULT_B = 0x58f38ded
    MIX_MULT_L = 0xca01f9dd
    MIX_MULT_R = 0x4973f715
    XSHIFT = 4 * 8 // 2
    MASK32 = 0xFFFFFFFF

    entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0]
    pool = [0] * 4

    hash_const_A = INIT_A

    def hash(value):
        nonlocal hash_const_A
        value = (value ^ hash_const_A) & MASK32
        hash_const_A = (hash_const_A * MULT_A) & MASK32
        value = (value * hash_const_A) & MASK32
        value = (value ^ (value >> XSHIFT)) & MASK32
        return value

    def mix(x, y):
        result_x = (MIX_MULT_L * x) & MASK32
        result_y = (MIX_MULT_R * y) & MASK32
        result = (result_x - result_y) & MASK32
        result = (result ^ (result >> XSHIFT)) & MASK32
        return result

    # Add in the entropy to the pool.
    for i in range(len(pool)):
        pool[i] = hash(entropy[i])

    # Mix all bits together so late bits can affect earlier bits.
    for i_src in range(len(pool)):
        for i_dst in range(len(pool)):
            if i_src != i_dst:
                pool[i_dst] = mix(pool[i_dst], hash(pool[i_src]))

    hash_const_B = INIT_B
    state = []
    for i_dst in range(4):
        data_val = pool[i_dst]
        data_val = (data_val ^ hash_const_B) & MASK32
        hash_const_B = (hash_const_B * MULT_B) & MASK32
        data_val = (data_val * hash_const_B) & MASK32
        data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32
        state.append(data_val)
    return state


def _check_sampler(sampler):
    if not hasattr(sampler, '__len__'):
        raise ValueError(f"sampler <{sampler.__class__}>: {sampler} should have attribute '__len__'")
    length = len(sampler)
    if not isinstance(length, int) or length <= 0:
        raise ValueError(f"len(sampler): {length} should be integer > 0")


def _check_batches_read(batches, batch_num, batch_size, drop_last):
    if not isinstance(batches, list) or not isinstance(batches[0], list) or not isinstance(batches[0][0], int):
        raise RuntimeError(f"batches read {batches} should be List[List[int]")

    if len(batches) < batch_num:
        raise RuntimeError(f"len(batches read) {len(batches)} != DataManager's batch num {batch_num}")

    if (len(batches) > 1 or drop_last) and len(batches[0]) != batch_size:
        raise RuntimeError(f"batch size of batches read {len(batches[0])} != "
                           f"DataManager's batch size {batch_size}")


def _check_batches_uniqueness(batches: kBatchesType, auger: DataAuger):
    """Forbid repeated indices within an epoch when the auger has random nodes.

    With random nodes, each auger index's random seeds are logged once per
    epoch; repeated indices would make the recorded seeds ambiguous when
    recovering a sample.

    Raises:
        RuntimeError: If indices repeat while auger.rand_nodes is non-empty.
    """
    indices = DataManager.batches2indices(batches)
    if len(indices) != len(set(indices)) and len(auger.rand_nodes) > 0:
        # Fixed message: the original f-string concatenation was missing a space
        # ("...rand_nodes indata_auger's graph").
        raise RuntimeError(f"batches generated by batch sampler can't have indices repeated when there is rand_nodes "
                           f"in data_auger's graph")


class DataManager(object):
    """A dataloader which can record and recover the process of loading"""

    # Guard flag read by __setattr__. The class-level default keeps attribute
    # assignment unrestricted while __init__ is still running.
    __initialized: bool = False

    def __init__(self, dataset: Union[Dataset, TorchDataset, None]=None, data_auger: Optional[DataAuger]=None,
                 log_dir: str='.', is_prefetch: bool=False, log_rand_seeds: bool=True,
                 batch_size: int=1, shuffle: bool=False, sampler: Optional[Sampler]=None,
                 batch_sampler: Optional[Sampler]=None, num_workers: int=0, collate_fn: Optional[Callable]=None,
                 pin_memory: bool=False, drop_last: bool=False, timeout: int=0,
                 worker_init_fn: Optional[Callable]=None,
                 generator: Optional[torch.Generator]=None, prefetch_factor: int=2):
        """ A dataloader which can record and recover the process of loading

        Args:
            dataset: Dataset to be loaded. Can be gotten from data_auger
            data_auger: DataAuger to be loaded. If None, will create an identity mapping auger with dataset.
            log_dir: Directory where DataManager saves its log
            is_prefetch: If True, data loader iter will be wrapped by Prefetcher, which can overlap the data transfer
                and calculating on GPU. Only usable when cuda is available. (Default: False)
            log_rand_seeds: If True, rand seeds will be recorded. (Default: True)
            batch_size: Same to param for torch.data.DataLoader
            shuffle: Same to param for torch.data.DataLoader
            sampler: Same to param for torch.data.DataLoader
            batch_sampler: Same to param for torch.data.DataLoader
            num_workers: Same to param for torch.data.DataLoader
            collate_fn: Same to param for torch.data.DataLoader
            pin_memory: Same to param for torch.data.DataLoader
            drop_last: Same to param for torch.data.DataLoader
            timeout: Same to param for torch.data.DataLoader
            worker_init_fn: Same to param for torch.data.DataLoader
            generator: Same to param for torch.data.DataLoader
            prefetch_factor: Same to param for torch.data.DataLoader. Only forwarded when num_workers > 0.

        See Also:
            torch.utils.data.DataLoader: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
        """
        if dataset is None and data_auger is None:
            raise ValueError(f"dataset and data_auger can't be None at the same time")

        if dataset is not None and data_auger is not None:
            if data_auger.dataset is not dataset:
                raise ValueError(f"data_auger.dataset {data_auger.dataset} should be dataset {dataset}")

        # Fill in whichever of (dataset, data_auger) was omitted.
        if dataset is None:
            dataset = data_auger.dataset
        elif data_auger is None:
            data_auger = _IdentityMapAuger(dataset, slim=True)

        self._dataset = dataset
        self._data_auger = data_auger

        self.log_dir = log_dir
        self.is_prefetch = is_prefetch
        self.log_rand_seeds = log_rand_seeds

        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.collect_fn = collate_fn
        self.generator = generator
        self.prefetch_factor = prefetch_factor

        self.shuffle = shuffle
        self.sampler = sampler
        self.drop_last = drop_last
        self.batch_size = batch_size
        self.batch_sampler = batch_sampler

        if self.batch_sampler is not None:
            # A custom batch sampler fully determines batching; the flat options
            # must stay at their defaults (mirrors torch's DataLoader contract).
            if self.shuffle or self.sampler is not None or self.drop_last or self.batch_size != 1:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            _check_sampler(self.batch_sampler)
            # Infer the batch size from the first batch yielded by the sampler.
            self.batch_size = len(next(iter(self.batch_sampler)))
            self.shuffle = None  # UnKnown
            self.drop_last = None  # UnKnown
        else:
            if self.sampler is not None:
                if self.shuffle:
                    raise ValueError('sampler option is mutually exclusive with shuffle')
                _check_sampler(self.sampler)
                self.shuffle = None  # UnKnown
            elif self.shuffle:
                self.sampler = RandomSampler(self.data_source, generator=self.generator)
            else:
                self.sampler = SequentialSampler(self.data_source)

            self.batch_sampler = BatchSampler(self.sampler, batch_size, drop_last)

        # Epoch state: filled by move_epoch_to/start_epoch.
        self._epoch: int = -1
        self.epoch_batches: Optional[kBatchesType] = None
        self.epoch_loader: Optional[DataLoader] = None
        self.epoch_iter: Optional[Iterable] = None

        def worker_init_fn_(worker_id):
            # Seed python's `random` with the worker's torch seed and numpy with
            # a decorrelated state, then chain to the user's init fn (if any).
            seed = torch.initial_seed()
            random.seed(seed)
            np_seed = _generate_state(seed - worker_id, worker_id)
            np.random.seed(np_seed)
            if worker_init_fn is not None:
                worker_init_fn(worker_id)

        # The user's worker_init_fn is captured in the closure above; only the
        # wrapped version is stored (the original code assigned it twice).
        self.worker_init_fn = worker_init_fn_

        self.__initialized: bool = True

    @property
    def dataset(self):
        """The underlying dataset."""
        return self._dataset

    @property
    def data_auger(self):
        """The data auger wrapping the dataset."""
        return self._data_auger

    @property
    def auger(self):
        """Alias of data_auger."""
        return self._data_auger

    @property
    def data_source(self):
        """The object actually loaded by DataLoader (the auger)."""
        return self._data_auger

    @property
    def log_parent_dir(self):
        """Root log directory; per-rank when torch.distributed is initialized."""
        if dist.is_initialized():
            log_parent_dir = osp.join(self.log_dir, '.data_manager_log', f'rank{dist.get_rank()}')
        else:
            log_parent_dir = osp.join(self.log_dir, '.data_manager_log')
        return log_parent_dir

    def epoch_log_dir(self, epoch):
        """Log directory of the given epoch."""
        return osp.join(self.log_parent_dir, f"epoch-{epoch}")

    def batches_log_file(self, epoch):
        """Pickle file holding the given epoch's batches."""
        return osp.join(self.epoch_log_dir(epoch), 'batches.pkl')

    def rand_seed_log_file(self, epoch):
        """File holding the given epoch's auger rand seeds."""
        return osp.join(self.epoch_log_dir(epoch), 'rand_seed_log')

    def __setattr__(self, attr, val):
        # Core loading configuration is frozen after __init__: changing it would
        # desynchronize the logs already written for earlier epochs.
        # Fixed: the original tuple was missing a comma after '_dataset', which
        # concatenated it with '_data_auger' and left both attributes unprotected.
        if self.__initialized and attr in ('batch_size', 'batch_sampler', 'sampler', 'drop_last', '_dataset',
                                           '_data_auger', 'shuffle'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        super(DataManager, self).__setattr__(attr, val)

    @staticmethod
    def batches2indices(batches: kBatchesType):
        """Flat batches to indices"""
        return [index for batch in batches for index in batch]

    @property
    def batch_num(self):
        """Number of batches in each epoch"""
        return len(self.batch_sampler)

    def __len__(self):
        """Number of batches in each epoch"""
        return self.batch_num

    @property
    def epoch(self):
        """DataManager's current epoch"""
        return self._epoch

    @property
    def epoch_iteration(self):
        """DataManager's current epoch iteration. Need PyTorch >= 1.4

        NOTE(review): reads the DataLoader iterator's private `_num_yielded`;
        when is_prefetch is True, epoch_iter is a Prefetcher — assumes it
        exposes the same attribute. TODO confirm.
        """
        return self.epoch_iter._num_yielded

    @property
    def epoch_indices(self):
        """Flat indices of the current epoch, or None before any epoch starts."""
        return self.batches2indices(self.epoch_batches) if self.epoch_batches is not None else None

    def read_batches(self, epoch: int) -> Optional[kBatchesType]:
        """Return batches of specified epoch

        Args:
            epoch: epoch num

        Returns:
            If epoch record exits, return batches. Else return None.
        """
        if not osp.isfile(self.batches_log_file(epoch)):
            return None

        with open(self.batches_log_file(epoch), 'rb') as batches_log:
            batches: kBatchesType = pickle.load(batches_log)

        _check_batches_read(batches, self.batch_num, self.batch_size, self.drop_last)
        return batches

    def save_batches(self, epoch, batches):
        """Persist the given epoch's batches to its log file."""
        with open(self.batches_log_file(epoch), 'wb') as batches_log:
            pickle.dump(batches, batches_log)

    def read_rand_seeds(self, epoch, idx):
        """Read the logged rand seeds of auger index `idx` at `epoch`. None if absent."""
        return read_rand_seeds(self.rand_seed_log_file(epoch), idx, create=False)

    def locate_epoch(self, iteration: int) -> Tuple[int, int]:
        """Return epoch and epoch_iteration according to total iteration"""
        return iteration // self.batch_num, iteration % self.batch_num

    def backtrace(self, iteration: Optional[int] = None,
                  epoch: Optional[int] = None, epoch_iteration: Optional[int] = None,
                  epoch_loc: Optional[int] = None, iteration_loc: Optional[int] = None) -> dict:
        """Return the backtrace of specified sample

        Args:
            iteration: The total iteration where the sample at
            epoch: The epoch where the sample at
            epoch_iteration: The epoch iteration where the sample at
            epoch_loc: Number of sample in its epoch's indices
            iteration_loc: Number of sample in its iteration's indices

        Returns:
            A dict 'ret' with keys and values:
                ret['dataset_idx']: sample's dataset index
                ret['data_auger_idx']: sample's data_auger index
                ret['node_indices']: sample's data_auger.graph's node indices
                ret['rand_seeds']: sample's data_auger.graph's rand seeds
                ret['auger_input']: origin sample from dataset
                ret['auger_output']: data_auger's output of sample
                ret['graph_data']: sample's data_auger's calculate data
        """
        # Resolve (epoch, epoch_iteration) from the total iteration, cross-checking
        # against any explicitly given values.
        if iteration is not None:
            epoch_, epoch_iteration_ = self.locate_epoch(iteration)

            if epoch is not None and epoch != epoch_:
                raise ValueError(f"epoch calculated from iteration {epoch_} != input epoch {epoch}")
            else:
                epoch = epoch_

            if epoch_iteration is not None and epoch_iteration != epoch_iteration_:
                raise ValueError(f"epoch iteration calculated from iteration {epoch_iteration_} != "
                                 f"input epoch_iteration {epoch_iteration}")
            else:
                epoch_iteration = epoch_iteration_

        if epoch is None:
            raise ValueError(f"Either iteration or epoch needed for locating sample")

        # Resolve (epoch_iteration, iteration_loc) from the epoch-level sample
        # location, again cross-checking against explicit inputs.
        if epoch_loc is not None:
            epoch_iteration_, iteration_loc_ = epoch_loc // self.batch_size, epoch_loc % self.batch_size

            if epoch_iteration is not None and epoch_iteration != epoch_iteration_:
                raise ValueError(f"epoch iteration calculated from epoch_loc {epoch_iteration_} != "
                                 f"input epoch_iteration {epoch_iteration}")
            else:
                epoch_iteration = epoch_iteration_

            if iteration_loc is not None and iteration_loc != iteration_loc_:
                raise ValueError(f"iteration location calculated from epoch_loc {iteration_loc_} != "
                                 f"input iteration_loc {iteration_loc}")
            else:
                iteration_loc = iteration_loc_

        if epoch_iteration is None or iteration_loc is None:
            raise ValueError(f"epoch_iteration or iteration needed to get epoch iteration."
                             f"And epoch_loc or iteration_loc needed to get iteration location")

        batches = self.read_batches(epoch)
        if batches is None:
            raise RuntimeError(f"Can't find batches log")

        auger_idx = batches[epoch_iteration][iteration_loc]

        rand_seeds = self.read_rand_seeds(epoch, auger_idx)
        if rand_seeds is None:
            raise RuntimeError(f"Can't find rand seeds log")

        # Replay the logged seeds so the auger reproduces the original sample.
        self.auger.load_rand_seeds(rand_seeds)

        ret = {}
        ret['data_auger_idx'] = auger_idx
        ret['dataset_idx'], ret['node_indices'] = self.auger.calculate_indices(auger_idx)
        ret['rand_seeds'] = rand_seeds

        ret['auger_input'] = self.dataset[ret['dataset_idx']]
        ret['auger_output'] = self.auger[auger_idx]
        ret['graph_data'] = self.auger.graph.data

        return ret

    def move_epoch_to(self, epoch: int, epoch_iteration: int, recreate: bool=False):
        """Move current epoch to (epoch, epoch_iteration).

        Args:
            epoch: epoch moved to
            epoch_iteration: epoch iteration moved to
            recreate: If True, the origin record of epoch moved to will be recreated.
        """
        if isinstance(self.sampler, DistributedSampler):
            # Keep DistributedSampler's shuffling consistent with the epoch.
            self.sampler.set_epoch(epoch)
        self._epoch = epoch

        if recreate and osp.isdir(self.epoch_log_dir(epoch)):
            shutil.rmtree(self.epoch_log_dir(epoch))

        if not osp.isdir(self.epoch_log_dir(epoch)):
            os.makedirs(self.epoch_log_dir(epoch))

        batches = self.read_batches(epoch)
        if batches is None:
            # No log for this epoch yet: generate the batches and persist them.
            batches = list(iter(self.batch_sampler))
            _check_batches_uniqueness(batches, self.auger)
            self.save_batches(epoch, batches)
        self.epoch_batches = batches

        self.auger.rand_seed_log = self.rand_seed_log_file(epoch) if self.log_rand_seeds else None

        # Resume mid-epoch: only the batches not yet consumed are replayed.
        epoch_batch_sampler = _EpochBatchSampler(batches, epoch_iteration, self.data_source)

        loader_kwargs = dict(batch_sampler=epoch_batch_sampler, pin_memory=self.pin_memory,
                             num_workers=self.num_workers, timeout=self.timeout,
                             worker_init_fn=self.worker_init_fn, generator=self.generator)
        if self.collect_fn is not None:
            # Only forward collate_fn when given: its default changes with torch version.
            loader_kwargs['collate_fn'] = self.collect_fn
        if self.num_workers > 0:
            # prefetch_factor is only meaningful (and, in newer torch versions,
            # only legal) when worker processes are used.
            loader_kwargs['prefetch_factor'] = self.prefetch_factor
        self.epoch_loader = DataLoader(self.data_source, **loader_kwargs)

        self.epoch_iter = iter(self.epoch_loader) if not self.is_prefetch else Prefetcher(iter(self.epoch_loader))

    def start_epoch(self, iteration: Optional[int]=None, recreate: bool=False) -> Iterable:
        """Start an epoch.

        Args:
            iteration: If not None, start epoch at specified iteration. Else start epoch behind current epoch.
                (Default: None)
            recreate: If True, the origin record of epoch starting at moved to will be recreated.

        Returns:
            Current epoch's iter(DataLoader)
        """
        if iteration is not None:
            epoch, epoch_iteration = self.locate_epoch(iteration)
            self.move_epoch_to(epoch, epoch_iteration, recreate)
        else:
            self.move_epoch_to(self.epoch + 1, 0, recreate)

        return self.epoch_iter

    def next_batch(self, recreate: bool=False):
        """Get next batch from DataManager. This function can work infinitely without raising StopIteration.

        Args:
            recreate: If True, the origin record of new epoch moved to will be recreated.

        Returns:
            next batch
        """
        if self.epoch == -1:
            self.start_epoch(recreate=recreate)

        try:
            batch = next(self.epoch_iter)
        except StopIteration:
            # Current epoch exhausted: roll over to the next one.
            self.start_epoch(recreate=recreate)
            batch = next(self.epoch_iter)

        return batch

    def __iter__(self):
        # Exposes the current epoch's iterator; None until an epoch is started.
        return self.epoch_iter

    def __repr__(self):
        return f"DataManager <{self.__class__}>: \n" \
               + indent(f'epoch_batch_num: {len(self)}') + '\n' \
               + indent(f'batch_size: {self.batch_size}') + '\n' \
               + indent(f'shuffle: {self.shuffle}') + '\n' \
               + indent(f'drop_last: {self.drop_last}') + '\n' \
               + indent(f'num_workers: {self.num_workers}') + '\n' \
               + indent(f'pin_memory: {self.pin_memory}') + '\n' \
               + indent(f'is_prefetch: {self.is_prefetch}') + '\n' \
               + indent(f'dataset: {self.dataset}') + '\n' \
               + indent(f'data_auger: {self.data_auger}')